{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4.2 模型参数的访问、初始化和共享"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4.1\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import init\n",
    "\n",
    "print(torch.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Linear(in_features=4, out_features=3, bias=True)\n",
      "  (1): ReLU()\n",
      "  (2): Linear(in_features=3, out_features=1, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# PyTorch has already applied its default initialization to each layer on construction.\n",
    "net = nn.Sequential(\n",
    "    nn.Linear(4, 3),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(3, 1),\n",
    ")\n",
    "\n",
    "print(net)\n",
    "# Run a random batch of 2 samples through the network and reduce the result\n",
    "# to a scalar so that a later cell can call Y.backward().\n",
    "X = torch.rand(2, 4)\n",
    "Y = net(X).sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.2.1 访问模型参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'generator'>\n",
      "0.weight torch.Size([3, 4])\n",
      "0.bias torch.Size([3])\n",
      "2.weight torch.Size([1, 3])\n",
      "2.bias torch.Size([1])\n"
     ]
    }
   ],
   "source": [
    "# named_parameters() yields (name, parameter) pairs lazily, so its type is a generator.\n",
    "print(type(net.named_parameters()))\n",
    "# In a Sequential, each parameter name is prefixed with the index of its layer.\n",
    "for name, param in net.named_parameters():\n",
    "    print(name, param.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "weight torch.Size([3, 4]) <class 'torch.nn.parameter.Parameter'>\n",
      "bias torch.Size([3]) <class 'torch.nn.parameter.Parameter'>\n"
     ]
    }
   ],
   "source": [
    "# Indexing the Sequential returns a single layer; its parameter names carry no index prefix.\n",
    "for name, param in net[0].named_parameters():\n",
    "    print(name, param.shape, type(param))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "weight1\n"
     ]
    }
   ],
   "source": [
    "class MyModel(nn.Module):\n",
    "    \"\"\"Toy module showing which attributes are registered as parameters.\n",
    "\n",
    "    weight1 is wrapped in nn.Parameter and is therefore reported by\n",
    "    named_parameters(); weight2 is a plain tensor attribute and is not.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.weight1 = nn.Parameter(torch.rand(20, 20))\n",
    "        self.weight2 = torch.rand(20, 20)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Placeholder: this example only inspects the registered parameters.\n",
    "        return None\n",
    "\n",
    "\n",
    "n = MyModel()\n",
    "# Only weight1 is listed.\n",
    "for name, param in n.named_parameters():\n",
    "    print(name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.2719, -0.0898, -0.2462,  0.0655],\n",
      "        [-0.4669, -0.2703,  0.3230,  0.2067],\n",
      "        [-0.2708,  0.1171, -0.0995,  0.3913]])\n",
      "None\n",
      "tensor([[-0.2281, -0.0653, -0.1646, -0.2569],\n",
      "        [-0.1916, -0.0549, -0.1382, -0.2158],\n",
      "        [ 0.0000,  0.0000,  0.0000,  0.0000]])\n"
     ]
    }
   ],
   "source": [
    "# The weight is the first parameter of the first Linear layer.\n",
    "weight_0 = net[0].weight\n",
    "print(weight_0.data)\n",
    "# No gradient has been computed yet, so .grad is None before backward().\n",
    "print(weight_0.grad)\n",
    "Y.backward()\n",
    "print(weight_0.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.2.2 初始化模型参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.weight tensor([[ 0.0030,  0.0094,  0.0070, -0.0010],\n",
      "        [ 0.0001,  0.0039,  0.0105, -0.0126],\n",
      "        [ 0.0105, -0.0135, -0.0047, -0.0006]])\n",
      "2.weight tensor([[-0.0074,  0.0051,  0.0066]])\n"
     ]
    }
   ],
   "source": [
    "# Re-draw every weight from a normal distribution with mean 0 and std 0.01; biases are skipped.\n",
    "for name, param in net.named_parameters():\n",
    "    if 'weight' not in name:\n",
    "        continue\n",
    "    init.normal_(param, mean=0, std=0.01)\n",
    "    print(name, param.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.bias tensor([0., 0., 0.])\n",
      "2.bias tensor([0.])\n"
     ]
    }
   ],
   "source": [
    "# Set every bias to the constant 0; weights are skipped.\n",
    "for name, param in net.named_parameters():\n",
    "    if 'bias' not in name:\n",
    "        continue\n",
    "    init.constant_(param, val=0)\n",
    "    print(name, param.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.2.3 自定义初始化方法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_weight_(tensor):\n",
    "    \"\"\"Fill tensor in place from U(-10, 10), then zero entries whose magnitude is below 5.\n",
    "\n",
    "    The writes happen under torch.no_grad() so autograd does not record them.\n",
    "    Returns None; the argument itself is modified.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        tensor.uniform_(-10, 10)\n",
    "        # 1.0 where |value| >= 5, 0.0 elsewhere.\n",
    "        keep_mask = (tensor.abs() >= 5).float()\n",
    "        tensor.mul_(keep_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.weight tensor([[ 7.0403,  0.0000, -9.4569,  7.0111],\n",
      "        [-0.0000, -0.0000,  0.0000,  0.0000],\n",
      "        [ 9.8063, -0.0000,  0.0000, -9.7993]])\n",
      "2.weight tensor([[-5.8198,  7.7558, -5.0293]])\n"
     ]
    }
   ],
   "source": [
    "# Apply the custom initializer to every weight; biases are skipped.\n",
    "for name, param in net.named_parameters():\n",
    "    if 'weight' not in name:\n",
    "        continue\n",
    "    init_weight_(param)\n",
    "    print(name, param.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.bias tensor([1., 1., 1.])\n",
      "2.bias tensor([1.])\n"
     ]
    }
   ],
   "source": [
    "# Add 1 to every bias by writing to .data directly.\n",
    "for name, param in net.named_parameters():\n",
    "    if 'bias' not in name:\n",
    "        continue\n",
    "    param.data.add_(1)\n",
    "    print(name, param.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.2.4 共享模型参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Linear(in_features=1, out_features=1, bias=False)\n",
      "  (1): Linear(in_features=1, out_features=1, bias=False)\n",
      ")\n",
      "0.weight tensor([[3.]])\n"
     ]
    }
   ],
   "source": [
    "# The same Linear instance is placed at both positions, so the two layers share one weight.\n",
    "linear = nn.Linear(1, 1, bias=False)\n",
    "net = nn.Sequential(linear, linear)\n",
    "print(net)\n",
    "# The shared weight is reported only once, under the name of its first occurrence.\n",
    "for name, param in net.named_parameters():\n",
    "    init.constant_(param, val=3)\n",
    "    print(name, param.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "# Both positions hold the very same module object, and therefore the same weight.\n",
    "print(net[0] is net[1])\n",
    "print(net[0].weight is net[1].weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(9., grad_fn=<SumBackward0>)\n",
      "tensor([[6.]])\n"
     ]
    }
   ],
   "source": [
    "# With shared weight w = 3 and input 1, the output is w * w * 1 = 9.\n",
    "x = torch.ones(1, 1)\n",
    "y = net(x).sum()\n",
    "print(y)\n",
    "# The gradient accumulates over both uses of the shared weight: 2 * w = 6.\n",
    "y.backward()\n",
    "print(net[0].weight.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
